{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "b494ed88",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Global configuration\n",
    "# Read the Hugging Face access token; the with-block closes the file handle\n",
    "# immediately instead of leaving it open until garbage collection.\n",
    "with open('/root/hub_token.txt') as token_file:\n",
    "    hub_token = token_file.read().strip()\n",
    "\n",
    "# Hub repository that holds both the dataset and the fine-tuned pipeline\n",
    "repo_id = 'lansinuote/diffusion.3.dream_booth'\n",
    "# When True, the dataset and the trained pipeline are pushed to the Hub\n",
    "push_to_hub = True\n",
    "# Pretrained Stable Diffusion checkpoint used as the starting point\n",
    "checkpoint = 'CompVis/stable-diffusion-v1-4'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "315ad307",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using custom data configuration lansinuote--diffusion.3.dream_booth-4344a7d7501be7f1\n",
      "Found cached dataset parquet (/root/.cache/huggingface/datasets/lansinuote___parquet/lansinuote--diffusion.3.dream_booth-4344a7d7501be7f1/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(Dataset({\n",
       "     features: ['image', 'text'],\n",
       "     num_rows: 5\n",
       " }),\n",
       " {'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=2469x2558>,\n",
       "  'text': 'a photo of little dog'})"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from datasets import Dataset, DatasetDict, load_dataset\n",
    "import torchvision\n",
    "import PIL.Image\n",
    "\n",
    "\n",
    "def get_dataset():\n",
    "    \"\"\"Build the training DatasetDict from the five local example photos.\n",
    "\n",
    "    Opens images/0.jpeg through images/4.jpeg and pairs every image with the\n",
    "    same caption. Returns a DatasetDict whose only split is 'train'.\n",
    "    \"\"\"\n",
    "    caption = 'a photo of little dog'\n",
    "    records = []\n",
    "    for index in range(5):\n",
    "        picture = PIL.Image.open('images/%d.jpeg' % index)\n",
    "        records.append({'image': picture, 'text': caption})\n",
    "\n",
    "    train_split = Dataset.from_list(records)\n",
    "    return DatasetDict({'train': train_split})\n",
    "\n",
    "\n",
    "# Publish the locally assembled dataset only when uploading is enabled\n",
    "if push_to_hub:\n",
    "    get_dataset().push_to_hub(repo_id=repo_id, token=hub_token)\n",
    "\n",
    "# Use the already processed copy of the dataset hosted on the Hub\n",
    "dataset = load_dataset(path=repo_id, split='train')\n",
    "\n",
    "dataset, dataset[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "2af94061",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Dataset({\n",
       "     features: ['image', 'text'],\n",
       "     num_rows: 5\n",
       " }),\n",
       " {'pixel_values': tensor([[[ 0.7725,  0.7804,  0.7804,  ...,  0.7647,  0.7725,  0.7647],\n",
       "           [ 0.7725,  0.7882,  0.7725,  ...,  0.7647,  0.7647,  0.7647],\n",
       "           [ 0.7804,  0.7804,  0.7725,  ...,  0.7647,  0.7804,  0.7647],\n",
       "           ...,\n",
       "           [ 0.7098,  0.7176,  0.7176,  ...,  0.6941,  0.7098,  0.7020],\n",
       "           [ 0.7098,  0.7255,  0.7255,  ...,  0.6941,  0.6863,  0.7098],\n",
       "           [ 0.7020,  0.7176,  0.7176,  ...,  0.6941,  0.7020,  0.7098]],\n",
       "  \n",
       "          [[-0.1529, -0.1451, -0.1451,  ..., -0.1373, -0.1373, -0.1373],\n",
       "           [-0.1529, -0.1451, -0.1529,  ..., -0.1451, -0.1529, -0.1373],\n",
       "           [-0.1373, -0.1451, -0.1608,  ..., -0.1529, -0.1451, -0.1373],\n",
       "           ...,\n",
       "           [-0.2549, -0.2471, -0.2471,  ..., -0.2471, -0.2314, -0.2392],\n",
       "           [-0.2627, -0.2471, -0.2471,  ..., -0.2471, -0.2549, -0.2314],\n",
       "           [-0.2706, -0.2549, -0.2471,  ..., -0.2471, -0.2392, -0.2314]],\n",
       "  \n",
       "          [[-0.6392, -0.6314, -0.6314,  ..., -0.6314, -0.6314, -0.6392],\n",
       "           [-0.6392, -0.6235, -0.6392,  ..., -0.6392, -0.6471, -0.6392],\n",
       "           [-0.6235, -0.6314, -0.6471,  ..., -0.6471, -0.6314, -0.6392],\n",
       "           ...,\n",
       "           [-0.6627, -0.6549, -0.6549,  ..., -0.6549, -0.6392, -0.6549],\n",
       "           [-0.6706, -0.6549, -0.6471,  ..., -0.6549, -0.6627, -0.6471],\n",
       "           [-0.6706, -0.6549, -0.6471,  ..., -0.6549, -0.6549, -0.6471]]]),\n",
       "  'input_ids': tensor([49406,   320,  1125,   539,  1274,  1929, 49407, 49407, 49407, 49407,\n",
       "          49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,\n",
       "          49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,\n",
       "          49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,\n",
       "          49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,\n",
       "          49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,\n",
       "          49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,\n",
       "          49407, 49407, 49407, 49407, 49407, 49407, 49407]),\n",
       "  'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "          0, 0, 0, 0, 0])})"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torchvision\n",
    "from transformers import AutoTokenizer\n",
    "\n",
    "# Image preprocessing and augmentation applied to every training photo\n",
    "compose = torchvision.transforms.Compose(transforms=[\n",
    "    torchvision.transforms.Resize(\n",
    "        size=512,\n",
    "        interpolation=torchvision.transforms.InterpolationMode.BILINEAR,\n",
    "    ),\n",
    "    torchvision.transforms.RandomCrop(size=512),\n",
    "    torchvision.transforms.ToTensor(),\n",
    "    torchvision.transforms.Normalize(mean=[0.5], std=[0.5]),\n",
    "])\n",
    "\n",
    "# Text tokenizer shipped with the checkpoint (slow implementation requested)\n",
    "tokenizer = AutoTokenizer.from_pretrained(\n",
    "    checkpoint,\n",
    "    subfolder='tokenizer',\n",
    "    use_fast=False,\n",
    ")\n",
    "\n",
    "\n",
    "def f(data):\n",
    "    \"\"\"Batch transform registered through Dataset.with_transform.\n",
    "\n",
    "    Args:\n",
    "        data: dict of columns; data['image'] is a list of PIL images and\n",
    "            data['text'] a list of captions of the same length.\n",
    "\n",
    "    Returns:\n",
    "        dict with 'pixel_values' (one augmented image tensor per row) and the\n",
    "        tokenizer's 'input_ids' and 'attention_mask', each padded to the\n",
    "        tokenizer's maximum length.\n",
    "    \"\"\"\n",
    "    # Transform every image of the batch rather than only the first one, so\n",
    "    # slices and multi-row batches keep all their rows. Converting to RGB\n",
    "    # keeps the channel count fixed for non-RGB source files.\n",
    "    pixel_values = torch.stack(\n",
    "        [compose(image.convert('RGB')) for image in data['image']])\n",
    "\n",
    "    # The caption is identical for the whole dataset, so it could also be\n",
    "    # encoded just once; model_max_length is 77 for this tokenizer.\n",
    "    tokens = tokenizer(list(data['text']),\n",
    "                       truncation=True,\n",
    "                       padding='max_length',\n",
    "                       max_length=tokenizer.model_max_length,\n",
    "                       return_tensors='pt')\n",
    "\n",
    "    return {\n",
    "        'pixel_values': pixel_values,\n",
    "        'input_ids': tokens.input_ids,\n",
    "        'attention_mask': tokens.attention_mask\n",
    "    }\n",
    "\n",
    "\n",
    "# Attach f as an on-the-fly transform: rows are processed when accessed\n",
    "dataset = dataset.with_transform(transform=f)\n",
    "\n",
    "dataset, dataset[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "c07c2a12",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(5,\n",
       " {'pixel_values': tensor([[[[ 0.7647,  0.7490,  0.7490,  ...,  0.7569,  0.7412,  0.7490],\n",
       "            [ 0.7490,  0.7647,  0.7412,  ...,  0.7412,  0.7490,  0.7333],\n",
       "            [ 0.7647,  0.7569,  0.7569,  ...,  0.7333,  0.7333,  0.7412],\n",
       "            ...,\n",
       "            [ 0.6863,  0.6863,  0.6627,  ...,  0.6784,  0.6392,  0.6784],\n",
       "            [ 0.6863,  0.6784,  0.6706,  ...,  0.6627,  0.6627,  0.6627],\n",
       "            [ 0.6863,  0.6549,  0.6627,  ...,  0.6706,  0.6706,  0.6706]],\n",
       "  \n",
       "           [[-0.1608, -0.1765, -0.1765,  ..., -0.1843, -0.1922, -0.1765],\n",
       "            [-0.1765, -0.1608, -0.1843,  ..., -0.1843, -0.1765, -0.1922],\n",
       "            [-0.1608, -0.1686, -0.1686,  ..., -0.1922, -0.1843, -0.1922],\n",
       "            ...,\n",
       "            [-0.2549, -0.2549, -0.2863,  ..., -0.2392, -0.2706, -0.2314],\n",
       "            [-0.2627, -0.2706, -0.2784,  ..., -0.2549, -0.2471, -0.2471],\n",
       "            [-0.2627, -0.2941, -0.2863,  ..., -0.2392, -0.2392, -0.2392]],\n",
       "  \n",
       "           [[-0.6392, -0.6549, -0.6549,  ..., -0.6314, -0.6392, -0.6392],\n",
       "            [-0.6549, -0.6392, -0.6627,  ..., -0.6392, -0.6314, -0.6471],\n",
       "            [-0.6392, -0.6471, -0.6471,  ..., -0.6549, -0.6471, -0.6471],\n",
       "            ...,\n",
       "            [-0.6314, -0.6392, -0.6549,  ..., -0.6392, -0.6706, -0.6314],\n",
       "            [-0.6392, -0.6471, -0.6392,  ..., -0.6549, -0.6471, -0.6471],\n",
       "            [-0.6314, -0.6549, -0.6392,  ..., -0.6392, -0.6392, -0.6392]]]]),\n",
       "  'input_ids': tensor([[49406,   320,  1125,   539,  1274,  1929, 49407, 49407, 49407, 49407,\n",
       "           49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,\n",
       "           49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,\n",
       "           49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,\n",
       "           49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,\n",
       "           49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,\n",
       "           49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,\n",
       "           49407, 49407, 49407, 49407, 49407, 49407, 49407]]),\n",
       "  'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "           0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "           0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "           0, 0, 0, 0, 0]])})"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Shuffling loader over the transformed dataset, one sample per step,\n",
    "# using the default collate function\n",
    "loader = torch.utils.data.DataLoader(\n",
    "    dataset=dataset,\n",
    "    batch_size=1,\n",
    "    shuffle=True,\n",
    "    collate_fn=None,\n",
    ")\n",
    "\n",
    "len(loader), next(iter(loader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "e0459d36",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "encoder 12306.048\n",
      "vae 8365.3863\n",
      "unet 85952.0964\n"
     ]
    }
   ],
   "source": [
    "from transformers.models.clip.modeling_clip import CLIPTextModel\n",
    "from diffusers import AutoencoderKL, UNet2DConditionModel\n",
    "\n",
    "# Only the UNet is fine-tuned. requires_grad_ returns the module itself, so\n",
    "# the text encoder and the VAE are frozen right where they are loaded.\n",
    "encoder = CLIPTextModel.from_pretrained(\n",
    "    checkpoint, subfolder='text_encoder').requires_grad_(False)\n",
    "vae = AutoencoderKL.from_pretrained(\n",
    "    checkpoint, subfolder='vae').requires_grad_(False)\n",
    "unet = UNet2DConditionModel.from_pretrained(checkpoint, subfolder='unet')\n",
    "\n",
    "\n",
    "def print_model_size(name, model):\n",
    "    \"\"\"Print name followed by the model's parameter count in units of 10,000.\"\"\"\n",
    "    total = 0\n",
    "    for param in model.parameters():\n",
    "        total += param.numel()\n",
    "    print(name, total / 10000)\n",
    "\n",
    "\n",
    "# Report the size of each of the three sub-models\n",
    "for label, module in (('encoder', encoder), ('vae', vae), ('unet', unet)):\n",
    "    print_model_size(label, module)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "c6f4b455",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(DDPMScheduler {\n",
       "   \"_class_name\": \"DDPMScheduler\",\n",
       "   \"_diffusers_version\": \"0.15.1\",\n",
       "   \"beta_end\": 0.012,\n",
       "   \"beta_schedule\": \"scaled_linear\",\n",
       "   \"beta_start\": 0.00085,\n",
       "   \"clip_sample\": false,\n",
       "   \"clip_sample_range\": 1.0,\n",
       "   \"dynamic_thresholding_ratio\": 0.995,\n",
       "   \"num_train_timesteps\": 1000,\n",
       "   \"prediction_type\": \"epsilon\",\n",
       "   \"sample_max_value\": 1.0,\n",
       "   \"set_alpha_to_one\": false,\n",
       "   \"skip_prk_steps\": true,\n",
       "   \"steps_offset\": 1,\n",
       "   \"thresholding\": false,\n",
       "   \"trained_betas\": null,\n",
       "   \"variance_type\": \"fixed_small\"\n",
       " },\n",
       " AdamW (\n",
       " Parameter Group 0\n",
       "     amsgrad: False\n",
       "     betas: (0.9, 0.999)\n",
       "     capturable: False\n",
       "     eps: 1e-08\n",
       "     foreach: None\n",
       "     lr: 5e-06\n",
       "     maximize: False\n",
       "     weight_decay: 0.01\n",
       " ),\n",
       " MSELoss())"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from diffusers import DDPMScheduler\n",
    "\n",
    "# Noise schedule stored with the checkpoint\n",
    "scheduler = DDPMScheduler.from_pretrained(checkpoint, subfolder='scheduler')\n",
    "\n",
    "# AdamW over the UNet weights only\n",
    "optimizer = torch.optim.AdamW(\n",
    "    params=unet.parameters(),\n",
    "    lr=5e-6,\n",
    "    betas=(0.9, 0.999),\n",
    "    eps=1e-8,\n",
    "    weight_decay=0.01,\n",
    ")\n",
    "\n",
    "# Mean squared error between the predicted and the sampled noise\n",
    "criterion = torch.nn.MSELoss(reduction='mean')\n",
    "\n",
    "scheduler, optimizer, criterion"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "b3effea6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.0311, grad_fn=<MseLossBackward0>)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def get_loss(data):\n",
    "    \"\"\"Compute the noise-prediction loss for one batch.\n",
    "\n",
    "    Args:\n",
    "        data: dict with 'pixel_values' [b, 3, 512, 512], 'input_ids' [b, 77]\n",
    "            and 'attention_mask' [b, 77], on the same device as the models.\n",
    "\n",
    "    Returns:\n",
    "        Scalar tensor: criterion applied to the UNet output and the noise\n",
    "        that was mixed into the latents.\n",
    "    \"\"\"\n",
    "    # Encode the prompt. The encoder is not trained, so this could also be\n",
    "    # computed a single time and reused.\n",
    "    # [b, 77] -> [b, 77, 768]\n",
    "    # NOTE(review): the stock Stable Diffusion pipeline appears to call the\n",
    "    # text encoder without an attention mask unless its config enables one -\n",
    "    # confirm that passing the mask here matches inference.\n",
    "    out_encoder = encoder(input_ids=data['input_ids'],\n",
    "                          attention_mask=data['attention_mask'])[0]\n",
    "\n",
    "    # Encode the images into latents with the VAE\n",
    "    # [b, 3, 512, 512] -> [b, 4, 64, 64]\n",
    "    out_vae = vae.encode(data['pixel_values']).latent_dist.sample()\n",
    "\n",
    "    # Read the latent scaling factor (0.18215 for this checkpoint) from the\n",
    "    # VAE config instead of hard-coding it\n",
    "    out_vae = out_vae * vae.config.scaling_factor\n",
    "\n",
    "    # Random noise with the same shape as the latents\n",
    "    # [b, 4, 64, 64]\n",
    "    noise = torch.randn_like(out_vae)\n",
    "\n",
    "    # One random timestep per sample; the batch size and the number of\n",
    "    # training timesteps (1000 here) are no longer hard-coded\n",
    "    noise_step = torch.randint(0,\n",
    "                               scheduler.config.num_train_timesteps,\n",
    "                               (out_vae.shape[0], ),\n",
    "                               device=out_vae.device,\n",
    "                               dtype=torch.long)\n",
    "\n",
    "    # Mix the noise into the latents according to the sampled timesteps\n",
    "    # [b, 4, 64, 64]\n",
    "    out_vae_noise = scheduler.add_noise(out_vae, noise, noise_step)\n",
    "\n",
    "    # The UNet predicts the noise contained in the noisy latents\n",
    "    # [b, 4, 64, 64],[b, 77, 768] -> [b, 4, 64, 64]\n",
    "    out_unet = unet(out_vae_noise, noise_step, out_encoder).sample\n",
    "\n",
    "    return criterion(out_unet, noise)\n",
    "\n",
    "\n",
    "# Smoke test: one loss computation on random inputs of the expected shapes\n",
    "get_loss(\n",
    "    dict(\n",
    "        pixel_values=torch.randn(1, 3, 512, 512),\n",
    "        input_ids=torch.ones(1, 77, dtype=torch.long),\n",
    "        attention_mask=torch.ones(1, 77, dtype=torch.long),\n",
    "    ))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "13a8c245",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 0.15169205842539668\n",
      "10 5.945965459337458\n",
      "20 3.8129762182943523\n",
      "30 3.295396795729175\n",
      "40 3.654667090624571\n",
      "50 6.476580709917471\n",
      "60 4.5216363104991615\n",
      "70 3.3012475325958803\n",
      "80 3.4042327782372013\n",
      "90 4.213736498611979\n",
      "100 4.350477983767632\n",
      "110 4.34992079436779\n",
      "120 3.7817602755967528\n",
      "130 3.6707470717374235\n",
      "140 2.6561128611210734\n",
      "150 1.8198281603399664\n",
      "160 2.432444425066933\n",
      "170 3.2916298899799585\n",
      "180 3.0428609114605933\n",
      "190 3.1719170592259616\n"
     ]
    }
   ],
   "source": [
    "from diffusers import DiffusionPipeline\n",
    "from huggingface_hub import Repository, create_repo\n",
    "\n",
    "\n",
    "def train(epochs=200, log_every=10, save_dir='./save'):\n",
    "    \"\"\"Fine-tune the UNet on the loader and save the resulting pipeline.\n",
    "\n",
    "    Args:\n",
    "        epochs: number of passes over the loader.\n",
    "        log_every: print the running loss every this many epochs.\n",
    "        save_dir: directory the fine-tuned pipeline is written to.\n",
    "    \"\"\"\n",
    "    device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "    for module in (unet, vae, encoder):\n",
    "        module.to(device)\n",
    "    unet.train()\n",
    "\n",
    "    loss_sum = 0.0\n",
    "    step_count = 0\n",
    "    for epoch in range(epochs):\n",
    "        for data in loader:\n",
    "            # Move every tensor of the batch to the training device\n",
    "            data = {key: value.to(device) for key, value in data.items()}\n",
    "\n",
    "            loss = get_loss(data)\n",
    "\n",
    "            loss.backward()\n",
    "            # Clip the gradient norm to 1.0 before the optimizer update\n",
    "            torch.nn.utils.clip_grad_norm_(unet.parameters(), 1.0)\n",
    "            optimizer.step()\n",
    "            optimizer.zero_grad()\n",
    "\n",
    "            loss_sum += loss.item()\n",
    "            step_count += 1\n",
    "\n",
    "        if epoch % log_every == 0:\n",
    "            # Log the mean loss per step: the first window spans one epoch\n",
    "            # and later ones log_every epochs, so raw sums are not\n",
    "            # comparable between log lines.\n",
    "            print(epoch, loss_sum / max(step_count, 1))\n",
    "            loss_sum = 0.0\n",
    "            step_count = 0\n",
    "\n",
    "    # Rebuild the full pipeline around the tuned UNet and write it to disk\n",
    "    DiffusionPipeline.from_pretrained(\n",
    "        checkpoint, unet=unet, text_encoder=encoder).save_pretrained(save_dir)\n",
    "\n",
    "\n",
    "# Create and clone the Hub repository first so that the files train()\n",
    "# writes to ./save land inside the local working copy.\n",
    "if push_to_hub:\n",
    "    create_repo(repo_id, exist_ok=True, token=hub_token)\n",
    "    repo = Repository('./save', clone_from=repo_id, token=hub_token)\n",
    "\n",
    "# Training does not depend on the upload flag, so the notebook still\n",
    "# produces a model when publishing is switched off.\n",
    "train()\n",
    "\n",
    "if push_to_hub:\n",
    "    repo.push_to_hub()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:pt39]",
   "language": "python",
   "name": "conda-env-pt39-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
